{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "32e022a2",
   "metadata": {},
   "source": [
    "# Program-aided language model (PAL) chain\n",
    "\n",
    "Implements Program-Aided Language Models, as in https://arxiv.org/pdf/2211.10435.pdf.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1370e40f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_experimental.pal_chain import PALChain\n",
    "from langchain import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a58e15e",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(temperature=0, max_tokens=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "095adc76",
   "metadata": {},
   "source": [
    "## Math Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "beddcac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pal_chain = PALChain.from_math_prompt(llm, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e2eab9d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ef64b27",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mdef solution():\n",
      "    \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\"\"\"\n",
      "    cindy_pets = 4\n",
      "    marcia_pets = cindy_pets + 2\n",
      "    jan_pets = marcia_pets * 3\n",
      "    total_pets = cindy_pets + marcia_pets + jan_pets\n",
      "    result = total_pets\n",
      "    return result\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'28'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pal_chain.run(question)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0269d20a",
   "metadata": {},
   "source": [
    "## Colored Objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e524f81f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pal_chain = PALChain.from_colored_object_prompt(llm, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "03a237b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses. If I remove all the pairs of sunglasses from the desk, how many purple items remain on it?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a84a4352",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n",
      "objects = []\n",
      "objects += [('booklet', 'blue')] * 2\n",
      "objects += [('booklet', 'purple')] * 2\n",
      "objects += [('sunglasses', 'yellow')] * 2\n",
      "\n",
      "# Remove all pairs of sunglasses\n",
      "objects = [object for object in objects if object[0] != 'sunglasses']\n",
      "\n",
      "# Count number of purple objects\n",
      "num_purple = len([object for object in objects if object[1] == 'purple'])\n",
      "answer = num_purple\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished PALChain chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'2'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pal_chain.run(question)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc3d7f10",
   "metadata": {},
   "source": [
    "## Intermediate Steps\n",
    "You can also use the intermediate steps flag to return the code executed that generates the answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9d2d9c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "pal_chain = PALChain.from_colored_object_prompt(\n",
    "    llm, verbose=True, return_intermediate_steps=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b29b971b",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses. If I remove all the pairs of sunglasses from the desk, how many purple items remain on it?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a2c40c28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n",
      "objects = []\n",
      "objects += [('booklet', 'blue')] * 2\n",
      "objects += [('booklet', 'purple')] * 2\n",
      "objects += [('sunglasses', 'yellow')] * 2\n",
      "\n",
      "# Remove all pairs of sunglasses\n",
      "objects = [object for object in objects if object[0] != 'sunglasses']\n",
      "\n",
      "# Count number of purple objects\n",
      "num_purple = len([object for object in objects if object[1] == 'purple'])\n",
      "answer = num_purple\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "result = pal_chain({\"question\": question})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "efddd033",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"# Put objects into a list to record ordering\\nobjects = []\\nobjects += [('booklet', 'blue')] * 2\\nobjects += [('booklet', 'purple')] * 2\\nobjects += [('sunglasses', 'yellow')] * 2\\n\\n# Remove all pairs of sunglasses\\nobjects = [object for object in objects if object[0] != 'sunglasses']\\n\\n# Count number of purple objects\\nnum_purple = len([object for object in objects if object[1] == 'purple'])\\nanswer = num_purple\""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result[\"intermediate_steps\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd88594",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
